import random
from collections import namedtuple, deque
from models.r2d2_config import sequence_length, burn_in_length, eta, n_step, gamma, over_lapping_length, device
import torch
import numpy as np

# Record layouts used by the buffers in this module. All of them are created
# with the same typename, 'Transition'; they differ only in their fields.
Transition = namedtuple(
    'Transition',
    ('state', 'next_state', 'action', 'reward', 'mask', 'step', 'rnn_state'),
)
DIAL_Transition = namedtuple(
    'Transition',
    ('state', 'next_state', 'action', 'reward', 'mask', 'step', 'rnn_state',
     'other_state', 'other_hidden', 'other_action', 'other_mask', 'communicated',
     'target_reward', 'target_state', 'target_hidden', 'target_mask'),
)
OBL_Transition = namedtuple(
    'Transition',
    ('state', 'target', 'action', 'mask', 'rnn_state'),
)
MI_OBL_Transition = namedtuple(
    'Transition',
    ('state', 'next_state', 'target', 'action', 'mask', 'mi_term_mask', 'rnn_state'),
)

# Discrete action ids 0..5; SEND is the action DialLocalBuffer treats as a message.
ACTIONS = list(range(6))
LEFT, RIGHT, UP, DOWN, NOOP, SEND = ACTIONS

class LocalBuffer(object):
    """Per-actor buffer that turns environment steps into R2D2 training sequences.

    ``push`` folds consecutive steps into n-step transitions,
    ``push_local_memory`` groups those into windows of ``sequence_length``
    transitions that overlap by ``over_lapping_length``, and ``sample`` drains
    the completed windows as a batched ``Transition``.
    """

    def __init__(self):
        self.n_step_memory = []           # raw steps of the n-step chunk in progress
        self.local_memory = []            # n-step transitions of the current window
        self.memory = []                  # completed [padded_window, real_length] pairs
        self.over_lapping_from_prev = []  # tail of the previous window, reused as context

    @staticmethod
    def _n_step_return(rewards, discount):
        """Return ``rewards[0] + discount*rewards[1] + discount**2*rewards[2] + ...``.

        An empty ``rewards`` sequence gives 0.
        """
        total = 0
        # Horner-style accumulation from the last reward backwards.
        for r in reversed(rewards):
            total = r + discount * total
        return total

    def push(self, state, next_state, action, reward, mask, rnn_state):
        """Record one environment step.

        Steps are grouped into disjoint chunks. When a chunk reaches ``n_step``
        steps, or the current step is terminal (``mask == 0``), it is collapsed
        into one transition and passed on:

        - ``state``, ``action`` and ``rnn_state`` come from the chunk's first step.
        - ``next_state`` and ``mask`` come from its last step.
        - the reward is the discounted (``gamma``) sum of the chunk's rewards.
        - ``step`` is the number of steps folded.
        """
        self.n_step_memory.append([state, next_state, action, reward, mask, rnn_state])
        if len(self.n_step_memory) == n_step or mask == 0:
            first_state, _, first_action, _, _, first_rnn_state = self.n_step_memory[0]
            _, last_next_state, _, _, last_mask, _ = self.n_step_memory[-1]

            # Correct discounted n-step return: r_0 + gamma*r_1 + gamma^2*r_2 + ...
            n_step_reward = self._n_step_return([entry[3] for entry in self.n_step_memory], gamma)
            step = len(self.n_step_memory)
            self.push_local_memory(first_state, last_next_state, first_action, n_step_reward, last_mask, step, first_rnn_state)
            self.n_step_memory = []

    def push_local_memory(self, state, next_state, action, reward, mask, step, rnn_state):
        """Append one n-step transition to the current window.

        When the carried-over overlap plus the new transitions reach
        ``sequence_length``, or the transition is terminal (``mask == 0``), the
        window is zero-padded to ``sequence_length`` and stored as
        ``[window, real_length]``. A non-terminal window hands its last
        ``over_lapping_length`` transitions to the next window.

        ``rnn_state`` is stacked and reshaped to ``(2, -1)``; presumably it holds
        the recurrent (h, c) pair — TODO confirm against the model.
        """
        self.local_memory.append(Transition(state, next_state, action, reward, mask, step, torch.stack(rnn_state).view(2, -1)))
        if (len(self.local_memory) + len(self.over_lapping_from_prev)) == sequence_length or mask == 0:
            self.local_memory = self.over_lapping_from_prev + self.local_memory
            length = len(self.local_memory)
            # Pad short (terminal) windows so every stored sequence has the same length.
            while len(self.local_memory) < sequence_length:
                self.local_memory.append(Transition(
                    torch.zeros(state.size()).to(device),
                    torch.zeros(state.size()).to(device),
                    0,
                    0,
                    0,
                    0,
                    torch.zeros([2, 1, 16]).view(2, -1).to(device)
                ))
            self.memory.append([self.local_memory, length])
            if mask == 0:
                self.over_lapping_from_prev = []
            else:
                self.over_lapping_from_prev = self.local_memory[len(self.local_memory) - over_lapping_length:]
            self.local_memory = []

    def sample(self):
        """Return all completed windows as a batched ``Transition`` plus their lengths.

        Each field holds one tensor per window: ``torch.stack`` for tensor fields
        and ``torch.Tensor`` for scalar fields. The buffer is emptied afterwards.
        """
        episodes = self.memory
        batch_state, batch_next_state, batch_action, batch_reward, batch_mask, batch_step, batch_rnn_state = [], [], [], [], [], [], []
        lengths = []
        for episode, length in episodes:
            batch = Transition(*zip(*episode))
            batch_state.append(torch.stack(list(batch.state)))
            batch_next_state.append(torch.stack(list(batch.next_state)))
            batch_action.append(torch.Tensor(list(batch.action)))
            batch_reward.append(torch.Tensor(list(batch.reward)))
            batch_mask.append(torch.Tensor(list(batch.mask)))
            batch_step.append(torch.Tensor(list(batch.step)))
            batch_rnn_state.append(torch.stack(list(batch.rnn_state)))

            lengths.append(length)
        self.memory = []
        return Transition(batch_state, batch_next_state, batch_action, batch_reward, batch_mask, batch_step, batch_rnn_state), lengths

class DialLocalBuffer(object):
    """Per-actor DIAL sequence buffer that only keeps windows from episodes
    in which this agent has communicated.

    It works like ``LocalBuffer``: n-step folding, then overlapping windows of
    ``sequence_length``. Each transition additionally carries the other agent's
    data, a ``communicated`` flag and ``target_*`` fields. A completed window is
    stored only once ``communicated_in_episode`` is set. Windows that fill up
    before any communication are discarded, but their tail is still carried
    forward as overlap context, so the window never grows beyond
    ``sequence_length``.
    """

    def __init__(self):
        self.n_step_memory = []
        self.local_memory = []
        self.memory = []
        self.over_lapping_from_prev = []
        # Flag to indicate whether they communicated within an episode
        self.communicated_in_episode = False

    @staticmethod
    def _n_step_return(rewards, discount):
        """Return ``rewards[0] + discount*rewards[1] + discount**2*rewards[2] + ...``."""
        total = 0
        for r in reversed(rewards):
            total = r + discount * total
        return total

    @staticmethod
    def _padding_transition(state, other_state):
        """Build an all-zero ``DIAL_Transition`` used to pad short windows."""
        return DIAL_Transition(
            torch.zeros(state.size()).to(device),              # state
            torch.zeros(state.size()).to(device),              # next_state
            0,                                                 # action
            0,                                                 # reward
            0,                                                 # mask
            0,                                                 # step
            torch.zeros([2, 1, 16]).view(2, -1).to(device),    # rnn_state
            torch.zeros(other_state.size()).to(device),        # other_state
            torch.zeros([2, 1, 16]).view(2, -1).to(device),    # other_hidden
            0,                                                 # other_action
            0,                                                 # other_mask
            0,                                                 # communicated
            0,                                                 # target_reward
            torch.zeros(other_state.size()).to(device),        # target_state
            torch.zeros([2, 1, 16]).view(2, -1).to(device),    # target_hidden
            0                                                  # target_mask
        )

    def push(self, state, next_state, action, reward, mask, rnn_state, other_state, other_hidden, other_action, other_mask, target_reward, target_state, target_hidden, target_mask, both_in_booth_flag):
        """Record one environment step and fold disjoint n-step chunks.

        Every ``n_step`` steps, or on a terminal step (``mask == 0``), the chunk
        becomes one transition:

        - ``state``, ``action`` and ``rnn_state`` come from the first step.
        - ``next_state`` and ``mask`` come from the last step.
        - the reward is the ``gamma``-discounted sum of the chunk's rewards.
        """
        self.n_step_memory.append([state, next_state, action, reward, mask, rnn_state, other_state, other_hidden, other_action, other_mask, target_reward, target_state, target_hidden, target_mask])
        # NOTE(review): the flush only checks ``mask``; push_local_memory also treats
        # ``other_mask == 0`` as episode end. The two agree only when n_step == 1.
        if len(self.n_step_memory) == n_step or mask == 0:
            first_state, _, first_action, _, _, first_rnn_state = self.n_step_memory[0][:6]
            _, last_next_state, _, _, last_mask, _ = self.n_step_memory[-1][:6]

            n_step_reward = self._n_step_return([entry[3] for entry in self.n_step_memory], gamma)
            step = len(self.n_step_memory)
            # NOTE(review): other_*, target_* and both_in_booth_flag come from the most
            # recent step, while state/action/rnn_state come from the first step of the
            # chunk. They only refer to the same step when n_step == 1.
            self.push_local_memory(first_state, last_next_state, first_action, n_step_reward, last_mask, step, first_rnn_state,
                                   other_state, other_hidden, other_action, other_mask,
                                   target_reward, target_state, target_hidden, target_mask, both_in_booth_flag)
            self.n_step_memory = []

    def push_local_memory(self, state, next_state, action, reward, mask, step, rnn_state, other_state, other_hidden, other_action, other_mask, target_reward, target_state, target_hidden, target_mask, both_in_booth_flag):
        """Append one n-step transition and close the window when it is complete.

        A window closes when the carried overlap plus the new transitions reach
        ``sequence_length``, or when either agent's episode ends (``mask == 0``
        or ``other_mask == 0``). A closed window is handled as follows:

        - If this agent has communicated during the episode, the window is
          zero-padded and stored as ``[window, real_length]``. Otherwise it is
          discarded.
        - At episode end, the overlap and the communication flag are reset.
        - Otherwise, the last ``over_lapping_length`` transitions seed the next
          window.
        """
        communicated = (both_in_booth_flag and action == SEND)
        self.local_memory.append(DIAL_Transition(state, next_state, action, reward, mask, step, torch.stack(rnn_state).view(2, -1), other_state, torch.stack(other_hidden).view(2, -1), other_action, other_mask, communicated, target_reward, target_state, torch.stack(target_hidden).view(2, -1), target_mask))
        self.communicated_in_episode = self.communicated_in_episode or communicated

        episode_over = (mask == 0 or other_mask == 0)
        # ``>=`` rather than ``==`` so the window can never silently outgrow sequence_length.
        window_full = (len(self.local_memory) + len(self.over_lapping_from_prev)) >= sequence_length
        if not (episode_over or window_full):
            return

        window = self.over_lapping_from_prev + self.local_memory
        if self.communicated_in_episode:
            length = len(window)
            padded = list(window)
            while len(padded) < sequence_length:
                padded.append(self._padding_transition(state, other_state))
            self.memory.append([padded, length])

        if episode_over:
            self.over_lapping_from_prev = []
            self.communicated_in_episode = False
        else:
            # Carry the tail forward even when the window was discarded, so the
            # window that eventually contains a communication keeps its context.
            self.over_lapping_from_prev = window[max(0, len(window) - over_lapping_length):]
        self.local_memory = []

    def sample(self):
        """Return all stored windows as a batched ``DIAL_Transition`` plus their lengths.

        Tensor fields are ``torch.stack``-ed per window and scalar fields become
        ``torch.Tensor``. The buffer is emptied afterwards.
        """
        episodes = self.memory
        batch_state, batch_next_state, batch_action, batch_reward, batch_mask, batch_step, batch_rnn_state = [], [], [], [], [], [], []
        batch_other_state, batch_other_hidden, batch_other_action, batch_other_mask = [], [], [], []
        batch_communicated, batch_target_reward, batch_target_state, batch_target_hidden, batch_target_mask = [], [], [], [], []
        lengths = []
        for episode, length in episodes:
            batch = DIAL_Transition(*zip(*episode))
            batch_state.append(torch.stack(list(batch.state)))
            batch_next_state.append(torch.stack(list(batch.next_state)))
            batch_action.append(torch.Tensor(list(batch.action)))
            batch_reward.append(torch.Tensor(list(batch.reward)))
            batch_mask.append(torch.Tensor(list(batch.mask)))
            batch_step.append(torch.Tensor(list(batch.step)))
            batch_rnn_state.append(torch.stack(list(batch.rnn_state)))

            batch_other_state.append(torch.stack(list(batch.other_state)))
            batch_other_hidden.append(torch.stack(list(batch.other_hidden)))
            batch_other_action.append(torch.Tensor(list(batch.other_action)))
            batch_other_mask.append(torch.Tensor(list(batch.other_mask)))

            batch_communicated.append(torch.Tensor(list(batch.communicated)))

            batch_target_reward.append(torch.Tensor(list(batch.target_reward)))
            batch_target_state.append(torch.stack(list(batch.target_state)))
            batch_target_hidden.append(torch.stack(list(batch.target_hidden)))
            batch_target_mask.append(torch.Tensor(list(batch.target_mask)))

            lengths.append(length)
        self.memory = []
        return DIAL_Transition(batch_state, batch_next_state, batch_action, batch_reward, batch_mask, batch_step, batch_rnn_state, batch_other_state, batch_other_hidden, batch_other_action, batch_other_mask, batch_communicated, batch_target_reward, batch_target_state, batch_target_hidden, batch_target_mask), lengths

    def clear_local(self):
        """Drop all in-progress (not yet stored) data, including the episode's communication flag."""
        self.n_step_memory = []
        self.local_memory = []
        self.over_lapping_from_prev = []
        # The flag is per-episode state; without this reset it would leak into the next episode.
        self.communicated_in_episode = False

class OBLLocalBuffer(object):
    """Per-actor buffer that chops an OBL trajectory into fixed-length windows.

    Each ``push`` records one ``OBL_Transition``. When the tail carried over
    from the previous window plus the new transitions reach
    ``sequence_length``, or the step is terminal (``mask == 0``), the window is
    zero-padded to ``sequence_length`` and stored together with its real
    (unpadded) length. ``sample`` drains everything collected so far.
    """

    def __init__(self):
        self.local_memory = []
        self.memory = []
        self.over_lapping_from_prev = []

    def push(self, state, target, action, mask, rnn_state):
        """Add one step and, if the window is complete, store it padded."""
        hidden = torch.stack(rnn_state).view(2, -1)
        self.local_memory.append(OBL_Transition(state, target, action, mask, hidden))

        window_size = len(self.local_memory) + len(self.over_lapping_from_prev)
        if not (window_size == sequence_length or mask == 0):
            return

        window = self.over_lapping_from_prev + self.local_memory
        real_length = len(window)
        while len(window) < sequence_length:
            window.append(OBL_Transition(
                torch.zeros(state.size()).to(device),             # state
                0,                                                # target
                0,                                                # action
                0,                                                # mask
                torch.zeros([2, 1, 16]).view(2, -1).to(device),   # rnn_state
            ))
        self.memory.append([window, real_length])

        # Terminal windows start the next episode from scratch; otherwise the
        # last over_lapping_length transitions seed the next window.
        if mask == 0:
            self.over_lapping_from_prev = []
        else:
            self.over_lapping_from_prev = window[len(window) - over_lapping_length:]
        self.local_memory = []

    def sample(self):
        """Return every stored window as a batched ``OBL_Transition`` and their lengths, then clear."""
        states, targets, actions, masks, hiddens = [], [], [], [], []
        lengths = []
        for window, real_length in self.memory:
            columns = OBL_Transition(*zip(*window))
            states.append(torch.stack(list(columns.state)))
            targets.append(torch.Tensor(list(columns.target)))
            actions.append(torch.Tensor(list(columns.action)))
            masks.append(torch.Tensor(list(columns.mask)))
            hiddens.append(torch.stack(list(columns.rnn_state)))
            lengths.append(real_length)
        self.memory = []
        return OBL_Transition(states, targets, actions, masks, hiddens), lengths

class MIOBLLocalBuffer(object):
    """Per-actor buffer that chops an MI-OBL trajectory into fixed-length windows.

    Same windowing as ``OBLLocalBuffer``, but each record also carries
    ``next_state`` and ``mi_term_mask``. A window is stored, zero-padded to
    ``sequence_length``, once it is full or the step is terminal (``mask == 0``).
    """

    def __init__(self):
        self.local_memory = []
        self.memory = []
        self.over_lapping_from_prev = []

    def push(self, state, next_state, target, action, mask, mi_term_mask, rnn_state):
        """Add one step and, if the window is complete, store it padded."""
        hidden = torch.stack(rnn_state).view(2, -1)
        self.local_memory.append(MI_OBL_Transition(state, next_state, target, action, mask, mi_term_mask, hidden))

        window_size = len(self.local_memory) + len(self.over_lapping_from_prev)
        if not (window_size == sequence_length or mask == 0):
            return

        window = self.over_lapping_from_prev + self.local_memory
        real_length = len(window)
        while len(window) < sequence_length:
            window.append(MI_OBL_Transition(
                torch.zeros(state.size()).to(device),             # state
                torch.zeros(next_state.size()).to(device),        # next_state
                0,                                                # target
                0,                                                # action
                0,                                                # mask
                torch.zeros((1, 6)).to(device),                   # mi_term_mask
                torch.zeros([2, 1, 16]).view(2, -1).to(device),   # rnn_state
            ))
        self.memory.append([window, real_length])

        if mask == 0:
            self.over_lapping_from_prev = []
        else:
            self.over_lapping_from_prev = window[len(window) - over_lapping_length:]
        self.local_memory = []

    def sample(self):
        """Return every stored window as a batched ``MI_OBL_Transition`` and their lengths, then clear.

        ``mi_term_mask`` is returned as a plain list per window (not converted to a tensor).
        """
        states, next_states, targets, actions, masks, mi_term_masks, hiddens = [], [], [], [], [], [], []
        lengths = []
        for window, real_length in self.memory:
            columns = MI_OBL_Transition(*zip(*window))
            states.append(torch.stack(list(columns.state)))
            next_states.append(torch.stack(list(columns.next_state)))
            targets.append(torch.Tensor(list(columns.target)))
            actions.append(torch.Tensor(list(columns.action)))
            masks.append(torch.Tensor(list(columns.mask)))
            mi_term_masks.append(list(columns.mi_term_mask))
            hiddens.append(torch.stack(list(columns.rnn_state)))
            lengths.append(real_length)
        self.memory = []
        return MI_OBL_Transition(states, next_states, targets, actions, masks, mi_term_masks, hiddens), lengths

class Memory(object):
    """Prioritized replay memory of fixed-length ``Transition`` sequences.

    Each entry is ``[Transition, length]``: every field of the ``Transition``
    holds one whole (padded) sequence, and ``length`` is its number of real
    steps. ``memory_probability`` holds one priority per entry. Both deques
    share ``capacity``, so they evict in lockstep.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = deque(maxlen=capacity)
        self.memory_probability = deque(maxlen=capacity)

    def td_error_to_prior(self, td_error, lengths):
        """Convert per-step TD errors into one priority per sequence.

        priority = eta * max|td| + (1 - eta) * sum|td| / (length - burn_in_length),
        with max and sum taken over dim 1 of ``td_error``. Presumably the shape is
        (num_sequences, steps[, 1]) — TODO confirm against the learner.

        Returns:
            numpy array with one priority per sequence.
        """
        # .cpu() is required before .numpy() when td_error lives on a GPU.
        abs_td_error = td_error.detach().abs().cpu()
        abs_td_error_sum = abs_td_error.sum(dim=1, keepdim=True).view(-1).numpy()
        prior_max = abs_td_error.max(dim=1, keepdim=True)[0].view(-1).numpy()

        # Clamp the divisor to >= 1. A sequence no longer than the burn-in would
        # otherwise give an inf/NaN/negative priority, and np.random.choice in
        # sample() rejects such probability vectors.
        lengths_burn = np.maximum(np.asarray(lengths) - burn_in_length, 1)

        prior_mean = abs_td_error_sum / lengths_burn
        return eta * prior_max + (1 - eta) * prior_mean

    def push(self, td_error, batch, lengths):
        """Store every sequence of a local batch together with its initial priority.

        ``batch`` is a ``Transition`` whose fields are per-sequence lists
        (``batch.state[i]`` is sequence ``i``), and ``lengths[i]`` is that
        sequence's real length.
        """
        prior = self.td_error_to_prior(td_error, lengths)

        # Iterate over sequences. ``len(batch)`` would be the number of
        # namedtuple fields (7), not the number of sequences.
        for i, length in enumerate(lengths):
            self.memory.append([Transition(batch.state[i], batch.next_state[i], batch.action[i], batch.reward[i], batch.mask[i], batch.step[i], batch.rnn_state[i]), length])
            self.memory_probability.append(prior[i])

    def sample(self, batch_size):
        """Draw ``batch_size`` entries (with replacement) proportionally to priority.

        Returns:
            (Transition of per-field lists, sampled indexes, real lengths)
        """
        probability = np.array(self.memory_probability)
        probability = probability / probability.sum()

        indexes = np.random.choice(range(len(self.memory_probability)), batch_size, p=probability)
        entries = [self.memory[idx] for idx in indexes]
        episodes = [entry[0] for entry in entries]
        lengths = [entry[1] for entry in entries]

        # One list per field, in Transition field order.
        columns = [[getattr(episode, field) for episode in episodes] for field in Transition._fields]
        return Transition(*columns), indexes, lengths

    def update_prior(self, indexes, td_error, lengths):
        """Overwrite the priorities of previously sampled entries."""
        prior = self.td_error_to_prior(td_error, lengths)
        for priors_idx, idx in enumerate(indexes):
            self.memory_probability[idx] = prior[priors_idx]

    def __len__(self):
        return len(self.memory)

class DIALRandomMemory(object):
    """Prioritized replay memory of ``DIAL_Transition`` sequences.

    Despite the name, sampling is priority-weighted (see ``sample``). Each
    entry is ``[DIAL_Transition, length]``, where every field holds one whole
    padded sequence.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = deque(maxlen=capacity)
        self.memory_probability = deque(maxlen=capacity)

    def push(self, td_error, batch, lengths):
        """Store every sequence of a local batch together with its initial priority."""
        prior = self.td_error_to_prior(td_error, lengths)
        for i, length in enumerate(lengths):
            self.memory.append([DIAL_Transition(batch.state[i], batch.next_state[i], batch.action[i], batch.reward[i], batch.mask[i], batch.step[i], batch.rnn_state[i],
                                                batch.other_state[i], batch.other_hidden[i], batch.other_action[i], batch.other_mask[i], batch.communicated[i],
                                                batch.target_reward[i], batch.target_state[i], batch.target_hidden[i], batch.target_mask[i]), length])
            # Priority derived from the TD error (not uniform).
            self.memory_probability.append(prior[i])

    def td_error_to_prior(self, td_error, lengths):
        """Convert per-step TD errors into one priority per sequence.

        priority = eta * max|td| + (1 - eta) * sum|td| / (length - burn_in_length),
        with max and sum taken over dim 1 of ``td_error``.
        """
        # .cpu() is required before .numpy() when td_error lives on a GPU.
        abs_td_error = td_error.detach().abs().cpu()
        abs_td_error_sum = abs_td_error.sum(dim=1, keepdim=True).view(-1).numpy()
        prior_max = abs_td_error.max(dim=1, keepdim=True)[0].view(-1).numpy()

        # Clamp to >= 1 so short sequences cannot yield inf/NaN/negative priorities.
        lengths_burn = np.maximum(np.asarray(lengths) - burn_in_length, 1)

        prior_mean = abs_td_error_sum / lengths_burn
        return eta * prior_max + (1 - eta) * prior_mean

    def sample(self, batch_size):
        """Draw ``batch_size`` entries (with replacement) proportionally to priority.

        Returns:
            (DIAL_Transition of per-field lists, sampled indexes, real lengths)
        """
        probability = np.array(self.memory_probability)
        probability = probability / probability.sum()

        indexes = np.random.choice(range(len(self.memory_probability)), batch_size, p=probability)
        entries = [self.memory[idx] for idx in indexes]
        episodes = [entry[0] for entry in entries]
        lengths = [entry[1] for entry in entries]

        # One list per field, in DIAL_Transition field order.
        columns = [[getattr(episode, field) for episode in episodes] for field in DIAL_Transition._fields]
        return DIAL_Transition(*columns), indexes, lengths

    def update_prior(self, indexes, td_error, lengths):
        """Overwrite the priorities of previously sampled entries."""
        prior = self.td_error_to_prior(td_error, lengths)
        for priors_idx, idx in enumerate(indexes):
            self.memory_probability[idx] = prior[priors_idx]

    def __len__(self):
        return len(self.memory)

class OBLMemory(object):
    """Prioritized replay memory of ``OBL_Transition`` sequences.

    Each entry is ``[OBL_Transition, length]`` with one priority per entry in
    ``memory_probability``. Both deques share ``capacity``.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = deque(maxlen=capacity)
        self.memory_probability = deque(maxlen=capacity)

    def td_error_to_prior(self, td_error, lengths):
        """Convert per-step TD errors into one priority per sequence.

        priority = eta * max|td| + (1 - eta) * sum|td| / (length - burn_in_length),
        with max and sum taken over dim 1 of ``td_error``.
        """
        # .cpu() is required before .numpy() when td_error lives on a GPU.
        abs_td_error = td_error.detach().abs().cpu()
        abs_td_error_sum = abs_td_error.sum(dim=1, keepdim=True).view(-1).numpy()
        prior_max = abs_td_error.max(dim=1, keepdim=True)[0].view(-1).numpy()

        # Clamp to >= 1 so short sequences cannot yield inf/NaN/negative priorities.
        lengths_burn = np.maximum(np.asarray(lengths) - burn_in_length, 1)

        prior_mean = abs_td_error_sum / lengths_burn
        return eta * prior_max + (1 - eta) * prior_mean

    def push(self, td_error, batch, lengths):
        """Store every sequence of a local batch together with its initial priority.

        ``batch`` is an ``OBL_Transition`` whose fields are per-sequence lists.
        """
        prior = self.td_error_to_prior(td_error, lengths)

        # Iterate over sequences. ``len(batch)`` would be the number of
        # namedtuple fields (5), not the number of sequences.
        for i, length in enumerate(lengths):
            self.memory.append([OBL_Transition(batch.state[i], batch.target[i], batch.action[i], batch.mask[i], batch.rnn_state[i]), length])
            self.memory_probability.append(prior[i])

    def sample(self, batch_size):
        """Draw ``batch_size`` entries (with replacement) proportionally to priority.

        Returns:
            (OBL_Transition of per-field lists, sampled indexes, real lengths)
        """
        probability = np.array(self.memory_probability)
        probability = probability / probability.sum()

        indexes = np.random.choice(range(len(self.memory_probability)), batch_size, p=probability)
        entries = [self.memory[idx] for idx in indexes]
        episodes = [entry[0] for entry in entries]
        lengths = [entry[1] for entry in entries]

        columns = [[getattr(episode, field) for episode in episodes] for field in OBL_Transition._fields]
        return OBL_Transition(*columns), indexes, lengths

    def update_prior(self, indexes, td_error, lengths):
        """Overwrite the priorities of previously sampled entries."""
        prior = self.td_error_to_prior(td_error, lengths)
        for priors_idx, idx in enumerate(indexes):
            self.memory_probability[idx] = prior[priors_idx]

    def __len__(self):
        return len(self.memory)

class MIOBLMemory(OBLMemory):
    """Prioritized replay memory of ``MI_OBL_Transition`` sequences.

    Inherits storage, ``td_error_to_prior``, ``update_prior`` and ``__len__``
    from ``OBLMemory``. Only the record type handled by ``push``/``sample``
    differs.
    """

    def __init__(self, capacity):
        super().__init__(capacity)

    def push(self, td_error, batch, lengths):
        """Store every sequence of a local batch together with its initial priority.

        ``batch`` is an ``MI_OBL_Transition`` whose fields are per-sequence lists.
        """
        prior = self.td_error_to_prior(td_error, lengths)

        # Iterate over sequences. ``len(batch)`` would be the number of
        # namedtuple fields (7), not the number of sequences.
        for i, length in enumerate(lengths):
            self.memory.append([MI_OBL_Transition(batch.state[i], batch.next_state[i], batch.target[i], batch.action[i], batch.mask[i], batch.mi_term_mask[i], batch.rnn_state[i]), length])
            self.memory_probability.append(prior[i])

    def sample(self, batch_size):
        """Draw ``batch_size`` entries (with replacement) proportionally to priority.

        Returns:
            (MI_OBL_Transition of per-field lists, sampled indexes, real lengths)
        """
        probability = np.array(self.memory_probability)
        probability = probability / probability.sum()

        indexes = np.random.choice(range(len(self.memory_probability)), batch_size, p=probability)
        entries = [self.memory[idx] for idx in indexes]
        episodes = [entry[0] for entry in entries]
        lengths = [entry[1] for entry in entries]

        columns = [[getattr(episode, field) for episode in episodes] for field in MI_OBL_Transition._fields]
        return MI_OBL_Transition(*columns), indexes, lengths
